Merge Gemma recipe with full finetune #668
Merged
Context
The primary reason Gemma had its own recipe was weight tying, where the output projection shares the token embedding weights. This replicates the behavior of `ReversibleEmbedding` in Keras, where the embedding weight can be reused to project back from the output dim to the input dim. This also had implications for FSDP wrapping and initializing on the meta device; see #630 and #616 for more discussion. We can actually achieve the same "weight tying" by getting rid of the output projection altogether and using the embedding weight directly for the output (shout-out @pbontrager):
output = F.linear(h, self.tok_embeddings.weight).float()
This is more akin to how it's done in `GemmaCausalLM` in Keras, where there is no output projection and the token embedding weight is used directly.
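For illustration, here is a minimal sketch of the idea in a toy decoder (the class and layer names here are illustrative only, not the actual torchtune module):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedOutputDecoder(nn.Module):
    """Toy decoder showing output "weight tying" without a separate projection."""

    def __init__(self, vocab_size: int, embed_dim: int) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        # ... attention / MLP layers omitted for brevity ...

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)  # [batch, seq, embed_dim]
        # ... h would pass through the transformer layers here ...
        # Reuse the embedding matrix as the output projection instead of a
        # dedicated nn.Linear(embed_dim, vocab_size, bias=False) parameter.
        return F.linear(h, self.tok_embeddings.weight).float()  # [batch, seq, vocab_size]
```

Because there is only a single shared parameter, FSDP wrapping and meta-device initialization no longer need special handling for a tied output weight.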
Changelog
- `GemmaTransformerDecoder`: removed the output projection; the token embedding weight is used directly for the output
- Removed the dedicated `gemma_full_finetune_distributed.py` recipe in favor of the shared `full_finetune_distributed` recipe
- Removed `load_shared_weights_utils` and `save_shared_weights_utils`
- Corresponding updates under `torchtune/models/`
Test plan
This run had nearly equivalent loss values to the gemma recipe on main:
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma/2B_full max_steps_per_epoch=5
tune run --nnodes 1 --nproc_per_node 4 gemma_full_finetune_distributed --config gemma/2B_full max_steps_per_epoch=5
Comparison with HF implementation:
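For reference, a logit parity check along these lines could be used (a sketch only; the `gemma_2b` builder import and the Hugging Face model loading shown here are assumptions for illustration, and checkpoint loading is elided):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchtune.models.gemma import gemma_2b

# Assumed setup: both models initialized from the same converted Gemma 2B weights.
hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

tune_model = gemma_2b()
# tune_model.load_state_dict(...)  # load the converted checkpoint here

inputs = tokenizer("The capital of France is", return_tensors="pt")

with torch.no_grad():
    hf_logits = hf_model(**inputs).logits
    tune_logits = tune_model(inputs["input_ids"])

# With matching weights the two sets of logits should agree up to numerical noise.
print(torch.max(torch.abs(hf_logits - tune_logits)))
```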